# -*- coding: utf-8 -*-
"""
Created on Thu Sep 26 10:21:16 2024

@author: xinyuan.wei
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score

# Species code used to select plots: only rows whose 'dominant_area_species'
# equals this value are fitted.
SPECIES_CODE = 833

# The input CSV lives one directory above the current working directory.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
file_path = os.path.join(parent_directory, 'NE_Plot_Location.csv')

# Read the CSV file
data = pd.read_csv(file_path)

# Keep only plots whose 'dominant_area_species' equals SPECIES_CODE, restricted
# to the two regression variables. Rows missing either value are dropped
# because scikit-learn's fit/r2_score reject NaN input.
filtered_data = data.loc[
    data['dominant_area_species'] == SPECIES_CODE, ['STDAGE', 'eAG']
].dropna()
if filtered_data.empty:
    raise ValueError(
        f"No rows with dominant_area_species == {SPECIES_CODE} and "
        f"non-missing STDAGE/eAG found in {file_path}"
    )

# scikit-learn expects a 2-D predictor of shape (n_samples, 1); the response
# stays 1-D.
X = filtered_data['STDAGE'].to_numpy().reshape(-1, 1)
y = filtered_data['eAG'].to_numpy()

# Linear Regression
linear_model = LinearRegression()
linear_model.fit(X, y)
y_pred_linear = linear_model.predict(X)
r2_linear = r2_score(y, y_pred_linear)
m = linear_model.coef_[0]
c = linear_model.intercept_
# Print the intercept's sign explicitly so a negative intercept reads
# "- 3.00" instead of "+ -3.00".
c_sign = '+' if c >= 0 else '-'
print(f'Linear Regression Equation: y = {m:.2f} * x {c_sign} {abs(c):.2f}')

# Exponential Regression
def exponential_func(x, a, b):
    """Evaluate the exponential model ``a * e**(b * x)``.

    Used as the model function passed to ``scipy.optimize.curve_fit``.

    Parameters
    ----------
    x : scalar or array-like
        Independent variable; evaluated element-wise through ``np.exp``.
    a : float
        Scale factor (the model's value at ``x == 0``).
    b : float
        Exponential rate multiplying ``x``.

    Returns
    -------
    The model value(s) at ``x``, with the same shape as ``b * x``.
    """
    growth = np.exp(b * x)
    return a * growth

# Initial guess for (a, b); curve_fit starts its least-squares search here.
initial_guess = [1, 0.01]

# Fit the exponential model. maxfev is raised above SciPy's default
# (200 * (n_params + 1) evaluations) so slow convergence does not abort the
# script with a RuntimeError.
popt, pcov = curve_fit(
    exponential_func, X.ravel(), y, p0=initial_guess, maxfev=10000
)
y_pred_exp = exponential_func(X.ravel(), *popt)
r2_exponential = r2_score(y, y_pred_exp)
a, b = popt
# Print b with significant figures rather than two decimals. A rate near the
# 0.01 initial guess would otherwise round to 0.00/0.01 and hide its value.
print(f'Exponential Regression Equation: y = {a:.2f} * exp({b:.4g} * x)')

# Sort by STDAGE before drawing the fitted curves. plt.plot connects points in
# the order given, so unsorted x would draw the exponential fit as
# back-and-forth chords instead of a curve.
x_values = X.ravel()
order = np.argsort(x_values)
x_sorted = x_values[order]

# Plotting
plt.figure(figsize=(10, 6))
plt.scatter(X, y, label='Data', color='blue')

# Plot linear regression line
plt.plot(x_sorted, y_pred_linear[order],
         label=f'Linear Fit (R² = {r2_linear:.2f})', color='red')

# Plot exponential regression curve
plt.plot(x_sorted, y_pred_exp[order],
         label=f'Exponential Fit (R² = {r2_exponential:.2f})', color='green')

# Annotate equations in axes-fraction coordinates (top-left corner). The
# intercept's sign is shown explicitly so a negative value reads "- c", not
# "+ -c".
ax = plt.gca()
c_sign = '+' if c >= 0 else '-'
plt.text(0.05, 0.95, f'Linear: y = {m:.2f}x {c_sign} {abs(c):.2f}',
         transform=ax.transAxes, fontsize=10, verticalalignment='top')
plt.text(0.05, 0.90, f'Exponential: y = {a:.2f}e^({b:.4g}x)',
         transform=ax.transAxes, fontsize=10, verticalalignment='top')

plt.xlabel('STDAGE')
plt.ylabel('eAG')
plt.title('Regression Models between STDAGE and eAG')
plt.legend()
plt.grid(True)
plt.show()
